{ "cells": [ { "cell_type": "markdown", "id": "69ef9e14", "metadata": {}, "source": [ "## **Single Stage -- Paradigm 1**\n", "\n", "### Real Data 1. Movie Lens\n", "\n", "Movie Lens is a movie recommendation website that helps users find movies and collects their ratings. The goal of this single-stage causal effect learning study is to infer the causal effect of recommending the treatment genre 'Drama' versus the control genre 'Sci-Fi'. This serves as an offline evaluation of how much users like or dislike one movie genre relative to the other, and hence gives a general sense of which genre to recommend in order to maximize user satisfaction.\n" ] }, { "cell_type": "markdown", "id": "Vx3GPf3t1Eo3", "metadata": { "id": "Vx3GPf3t1Eo3" }, "source": [ "#### Data Pre-processing" ] }, { "cell_type": "code", "execution_count": 1, "id": "21378417", "metadata": {}, "outputs": [], "source": [ "# import related packages\n", "import os\n", "import pickle\n", "import numpy as np\n", "\n", "from causaldm.learners.CPL4.CMAB import _env_realCMAB as _env\n", "data = _env.get_movielens()" ] }, { "cell_type": "code", "execution_count": 2, "id": "66804173", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['Individual', 'Xs', 'mean_ri', 'standardized_Xs'])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.keys()" ] }, { "cell_type": "code", "execution_count": 3, "id": "21dded94", "metadata": {}, "outputs": [], "source": [ "data_ML = data['Individual']" ] }, { "cell_type": "code", "execution_count": 4, "id": "5988cfc7", "metadata": {}, "outputs": [], "source": [ "userinfo_index = np.array([3,9,11,12,13,14])\n", "\n", "users_index = data_ML.keys()\n", "n = len(users_index) # the number of users\n", "movie_generes = ['Comedy', 'Drama', 'Action', 'Thriller', 'Sci-Fi']\n", "\n", "data_CEL = {}\n", " \n", "# initialize the final data we'll use in Causal Effect Learning\n", "for i in movie_generes:\n", 
" data_CEL[i] = None \n", "\n", "import pandas as pd\n", "for movie_genere in movie_generes:\n", " for user in users_index:\n", " data_CEL[movie_genere] = pd.concat([data_CEL[movie_genere] , data_ML[user][movie_genere]['complete']])\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "a4b8fa79", "metadata": {}, "outputs": [ { "data": { "text/html": [ "

|  | user_id | movie_id | rating | age | Comedy | Drama | Action | Thriller | Sci-Fi | gender_M | occupation_academic/educator | occupation_college/grad student | occupation_executive/managerial | occupation_other | occupation_technician/engineer |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 4220 | 48 | 2355.0 | 4.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 14400 | 48 | 2918.0 | 4.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 16752 | 48 | 2791.0 | 4.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 20195 | 48 | 2797.0 | 4.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 21689 | 48 | 2321.0 | 3.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 393463 | 5878.0 | 3299.0 | 3.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 395410 | 5878.0 | 892.0 | 5.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 396058 | 5878.0 | 574.0 | 1.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 397794 | 5878.0 | 1812.0 | 5.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 400719 | 5878.0 | 3830.0 | 1.0 | 25.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |

49563 rows × 15 columns
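
The pre-processing cell builds each per-genre table by concatenating one user's DataFrame at a time. As a minimal sketch of the same idea, the snippet below collects the per-user frames in a list and concatenates once per genre. `toy_data_ML`, `toy_data_CEL`, and the toy values are hypothetical stand-ins; only the nesting `data_ML[user][genre]['complete']` is taken from the notebook code.

```python
import pandas as pd

# Hypothetical toy stand-in for data_ML, mirroring the nesting used in the
# notebook: {user: {genre: {'complete': DataFrame}}}. Values are made up.
toy_data_ML = {
    48: {
        "Comedy": {"complete": pd.DataFrame({"user_id": [48, 48], "rating": [4.0, 3.0]})},
        "Drama": {"complete": pd.DataFrame({"user_id": [48], "rating": [5.0]})},
    },
    5878: {
        "Comedy": {"complete": pd.DataFrame({"user_id": [5878], "rating": [1.0]})},
        "Drama": {"complete": pd.DataFrame({"user_id": [5878, 5878], "rating": [2.0, 4.0]})},
    },
}
genres = ["Comedy", "Drama"]

# Gather every user's frame for a genre, then concatenate once per genre.
toy_data_CEL = {
    genre: pd.concat([toy_data_ML[user][genre]["complete"] for user in toy_data_ML])
    for genre in genres
}

print(toy_data_CEL["Comedy"])
```

Concatenating once per genre avoids re-copying the accumulated frame on every iteration, which can matter when there are many users.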

|  | user_id | movie_id | rating | age | Drama | gender_M | occupation_academic/educator | occupation_college/grad student | occupation_executive/managerial | occupation_other | occupation_technician/engineer |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 14 | 48 | 1193.0 | 4.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 11057 | 48 | 919.0 | 4.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 25871 | 48 | 527.0 | 5.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 31166 | 48 | 1721.0 | 4.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 40383 | 48 | 150.0 | 4.0 | 25.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 303406 | 5878.0 | 3300.0 | 2.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 320275 | 5878.0 | 1391.0 | 1.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 332011 | 5878.0 | 185.0 | 4.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 382221 | 5878.0 | 2232.0 | 1.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 397209 | 5878.0 | 426.0 | 3.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |

65642 rows × 11 columns
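
The table above carries a binary `Drama` column, which can be read as the treatment indicator described in the introduction (1 for 'Drama'; 0 is presumably the 'Sci-Fi' control, though that is an assumption here). As a rough illustration of what "causal effect of Drama versus Sci-Fi" means on such a table, the sketch below computes a naive difference in mean ratings on a made-up toy frame. It does not adjust for covariates such as age, gender, or occupation, and it is not the estimator used elsewhere in this notebook.

```python
import pandas as pd

# Hypothetical toy data: 'Drama' is the treatment indicator, 'rating' the outcome.
toy = pd.DataFrame({
    "Drama": [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
    "rating": [4.0, 5.0, 3.0, 2.0, 4.0, 3.0],
})

# Mean outcome under treatment (Drama == 1) and under control (Drama == 0).
mean_treated = toy.loc[toy["Drama"] == 1, "rating"].mean()
mean_control = toy.loc[toy["Drama"] == 0, "rating"].mean()

# Naive (unadjusted) difference in means; only a valid effect estimate
# if treatment assignment is unrelated to the covariates.
naive_ate = mean_treated - mean_control
print(naive_ate)
```

A positive value would suggest higher average ratings for Drama than for the control genre in the toy data; with observational data, covariate adjustment is generally needed before interpreting such a difference causally.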

|  | bloc | icustayid | charttime | gender | age | elixhauser | re_admission | died_in_hosp | died_within_48h_of_out_time | mortality_90d | ... | input_total | input_4hourly | output_total | output_4hourly | cumulated_balance | SOFA | SIRS | vaso_input | iv_input | reward |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 3 | 7245486000 | 0 | 17639.826435 | 0 | 0 | 0 | 0 | 1 | ... | 6527.0000 | 50.0 | 13617.0 | 520.0 | -7090.0000 | 5 | 1 | 0.0 | 2.0 | -0.884898 |
| 1 | 1 | 11 | 6898241400 | 1 | 30766.069028 | 6 | 1 | 0 | 0 | 0 | ... | 0.0000 | 0.0 | 0.0 | 0.0 | 0.0000 | 12 | 0 | 0.0 | 0.0 | 0.383136 |
| 2 | 1 | 12 | 5805732000 | 1 | 12049.217303 | 0 | 0 | 0 | 0 | 0 | ... | 0.0000 | 0.0 | 0.0 | 0.0 | 0.0000 | 4 | 2 | 0.0 | 0.0 | 0.976040 |
| 3 | 1 | 14 | 4264269300 | 0 | 30946.970000 | 2 | 0 | 0 | 0 | 1 | ... | 1300.0000 | 1300.0 | 340.0 | 160.0 | 960.0000 | 5 | 2 | 0.0 | 4.0 | 0.125000 |
| 4 | 1 | 30 | 5707825200 | 0 | 19793.588912 | 6 | 0 | 0 | 0 | 0 | ... | 9552.0000 | 50.0 | 6830.0 | 540.0 | 2722.0000 | 6 | 2 | 0.0 | 2.0 | 0.457625 |
| 5 | 1 | 33 | 7214122800 | 0 | 24524.747419 | 5 | 0 | 1 | 1 | 1 | ... | 10661.0483 | 725.0 | 5746.0 | 360.0 | 4915.0483 | 4 | 0 | 0.0 | 4.0 | 1.049099 |

6 rows × 62 columns

|  | Glucose | paO2 | PaO2_FiO2 | iv_input | SOFA | reward |
|---|---|---|---|---|---|---|
| 0 | 84.000000 | 84.000000 | 168.000000 | 2.0 | 5 | -0.884898 |
| 1 | 122.000000 | 59.444444 | 198.148148 | 0.0 | 12 | 0.383136 |
| 2 | 125.000000 | 192.000000 | 690.647482 | 0.0 | 4 | 0.976040 |
| 3 | 110.727273 | 179.000000 | 447.499993 | 4.0 | 5 | 0.125000 |
| 4 | 187.000000 | 125.000000 | 347.222222 | 2.0 | 6 | 0.457625 |
| ... | ... | ... | ... | ... | ... | ... |
| 4995 | 121.375000 | 136.787683 | 206.005547 | 3.0 | 4 | -1.965110 |
| 4996 | 108.000000 | 62.333333 | 143.846153 | 0.0 | 11 | -0.025000 |
| 4997 | 106.000000 | 258.500000 | 923.214286 | 0.0 | 7 | 0.402531 |
| 4998 | 144.000000 | 376.000000 | 752.000000 | 1.0 | 4 | -0.172130 |
| 4999 | 113.000000 | 108.000000 | 269.999996 | 4.0 | 5 | -0.025000 |

5000 rows × 6 columns

|  | Glucose | paO2 | PaO2_FiO2 | iv_input | SOFA | reward |
|---|---|---|---|---|---|---|
| 0 | 1.0 | 1.000000 | 1.000000 | 1.0 | 1 | 1.000000 |
| 1 | 122.0 | 59.444444 | 198.148148 | 0.0 | 12 | 0.383136 |
| 2 | 125.0 | 192.000000 | 690.647482 | 0.0 | 4 | 0.976040 |
| 3 | 1.0 | 1.000000 | 1.000000 | 1.0 | 1 | 1.000000 |
| 4 | 1.0 | 1.000000 | 1.000000 | 1.0 | 1 | 1.000000 |
| 5 | 1.0 | 1.000000 | 1.000000 | 1.0 | 1 | 1.000000 |